import torch
import sys
import os.path as osp
import warnings
from transformers import StoppingCriteriaList
from .base import BaseModel


class MiniGPT4(BaseModel):
    """Evaluation wrapper around MiniGPT-4 (v1) and MiniGPT-v2 chat models.

    Requires a local clone of https://github.com/Vision-CAIR/MiniGPT-4
    (pointed to by ``root``); that directory is appended to ``sys.path`` so
    the repo's ``minigpt4`` package can be imported at construction time.
    The model is loaded onto the current CUDA device.
    """

    INSTALL_REQ = True
    INTERLEAVE = False

    # Maps each supported model variant to its eval config shipped under ./misc.
    _MODE_TO_CFG = {
        "v2": "minigptv2_eval.yaml",
        "v1_7b": "minigpt4_7b_eval.yaml",
        "v1_13b": "minigpt4_13b_eval.yaml",
    }

    def __init__(
        self,
        mode="v2",
        root="/mnt/petrelfs/share_data/duanhaodong/MiniGPT-4/",
        temperature=1,
        max_out_len=512,
    ):
        """Load the MiniGPT-4 variant selected by *mode*.

        Args:
            mode: one of ``"v2"``, ``"v1_7b"``, ``"v1_13b"``.
            root: path to the MiniGPT-4 repository clone.
            temperature: sampling temperature forwarded to ``Chat.answer``.
            max_out_len: maximum number of new tokens to generate.

        Raises:
            NotImplementedError: if *mode* is not a known variant.
        """
        if root is None:
            warnings.warn(
                "Please set root to the directory of MiniGPT-4, which is cloned from here: "
                "https://github.com/Vision-CAIR/MiniGPT-4. "
            )

        # Validate mode before doing any heavy work; keep the original
        # NotImplementedError for unknown variants.
        if mode not in self._MODE_TO_CFG:
            raise NotImplementedError
        cfg_name = self._MODE_TO_CFG[mode]

        self.mode = mode
        self.temperature = temperature
        self.max_out_len = max_out_len
        self.root = root
        this_dir = osp.dirname(__file__)

        self.cfg = osp.join(this_dir, "misc", cfg_name)
        # Make the cloned repo importable (the imports below come from it).
        sys.path.append(self.root)

        from omegaconf import OmegaConf
        from minigpt4.common.registry import registry
        from minigpt4.conversation.conversation import (
            StoppingCriteriaSub,
            CONV_VISION_Vicuna0,
            CONV_VISION_minigptv2,
        )

        device = torch.cuda.current_device()
        self.device = device

        # NOTE: a distinct name for the loaded config — the original reused
        # one local for both the yaml filename and the OmegaConf object.
        config = OmegaConf.load(self.cfg)

        model_cfg = config.model
        model_cfg.device_8bit = device
        model_cls = registry.get_model_class(model_cfg.arch)
        model = model_cls.from_config(model_cfg)
        model = model.to(device)
        model.eval()
        vis_processor_cfg = config.datasets.cc_sbu_align.vis_processor.train
        vis_processor = registry.get_processor_class(
            vis_processor_cfg.name
        ).from_config(vis_processor_cfg)
        self.model = model
        self.vis_processor = vis_processor

        # Conversation template differs between MiniGPT-v2 and the v1 models.
        self.CONV_VISION = (
            CONV_VISION_minigptv2 if self.mode == "v2" else CONV_VISION_Vicuna0
        )
        # Stop sequences as token ids — presumably LLaMA tokenizer ids for
        # the "###"-style turn separators; TODO confirm against the repo.
        stop_words_ids = [[835], [2277, 29937]]
        stop_words_ids = [torch.tensor(ids).to(device) for ids in stop_words_ids]
        self.stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=stop_words_ids)]
        )

    def generate_inner(self, message, dataset=None):
        """Answer *message* (one image + text prompt) and return the reply text.

        Args:
            message: interleaved message list; reduced to a single
                (prompt, image_path) pair via ``message_to_promptimg``.
            dataset: optional dataset name forwarded to ``message_to_promptimg``.

        Returns:
            The generated answer string (first element of ``Chat.answer``).
        """
        from minigpt4.conversation.conversation import Chat

        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
        # v1 models rely on explicit stopping criteria; v2 handles stopping itself.
        if self.mode == "v2":
            chat = Chat(self.model, self.vis_processor, device=self.device)
        else:
            chat = Chat(
                self.model,
                self.vis_processor,
                device=self.device,
                stopping_criteria=self.stopping_criteria,
            )

        chat_state = self.CONV_VISION.copy()
        img_list = []
        _ = chat.upload_img(image_path, chat_state, img_list)
        chat.encode_img(img_list)
        chat.ask(prompt, chat_state)
        with torch.inference_mode():
            # Fix: forward the configured sampling parameters. Previously
            # self.temperature / self.max_out_len were stored in __init__
            # but never used, so Chat.answer silently ran with its own
            # defaults regardless of what the caller configured.
            msg = chat.answer(
                conv=chat_state,
                img_list=img_list,
                temperature=self.temperature,
                max_new_tokens=self.max_out_len,
            )[0]
        return msg
